{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#!wget http://baidudeeplearning.bj.bcebos.com/image_contest_level_1.tar.gz\n",
    "#!wget http://baidudeeplearning.bj.bcebos.com/image_contest_level_1_validate.tar.gz\n",
    "#!wget http://baidudeeplearning.bj.bcebos.com/image_contest_level_2.tar.gz\n",
    "#!wget http://baidudeeplearning.bj.bcebos.com/image_contest_level_2_validate.tar.gz\n",
    "#!tar -zxvf ./image_contest_level_1.tar.gz\n",
    "#!tar -zxvf ./image_contest_level_1_validate.tar.gz\n",
    "#!tar -zxvf ./image_contest_level_2.tar.gz\n",
    "#!tar -zxvf ./image_contest_level_2_validate.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成验证码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from captcha.image import ImageCaptcha\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import string\n",
    "characters = string.digits + string.ascii_uppercase\n",
    "print(characters)\n",
    "\n",
    "width, height, n_len, n_class = 170, 80, 4, len(characters)\n",
    "\n",
    "generator = ImageCaptcha(width=width, height=height)\n",
    "random_str = ''.join([random.choice(characters) for j in range(4)])\n",
    "img = generator.generate_image(random_str)\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.title(random_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def gen(batch_size=32):\n",
    "    X = np.zeros((batch_size, height, width, 3), dtype=np.uint8)\n",
    "    y = [np.zeros((batch_size, n_class), dtype=np.uint8) for i in range(n_len)]\n",
    "    generator = ImageCaptcha(width=width, height=height)\n",
    "    while True:\n",
    "        for i in range(batch_size):\n",
    "            random_str = ''.join([random.choice(characters) for j in range(4)])\n",
    "            X[i] = generator.generate_image(random_str)\n",
    "            for j, ch in enumerate(random_str):\n",
    "                y[j][i, :] = 0\n",
    "                y[j][i, characters.find(ch)] = 1\n",
    "        yield X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(y, index=0):\n",
    "    y = np.argmax(np.array(y), axis=2)[:, index]\n",
    "    return ''.join([characters[x] for x in y])\n",
    "\n",
    "X, y = next(gen(1))\n",
    "plt.imshow(X[0])\n",
    "plt.title(decode(y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import *\n",
    "from keras.layers import *\n",
    "from keras.optimizers import *\n",
    "\n",
    "input_tensor = Input((height, width, 3))\n",
    "x = input_tensor\n",
    "for i in range(4):\n",
    "    x = Convolution2D(32*2**i, 3, activation='relu')(x)\n",
    "    x = Convolution2D(32*2**i, 3, activation='relu')(x)\n",
    "    x = MaxPooling2D(2)(x)\n",
    "\n",
    "x = Flatten()(x)\n",
    "x = Dropout(0.25)(x)\n",
    "x = [Dense(n_class, activation='softmax', name='c%d'%(i+1))(x) for i in range(4)]\n",
    "model = Model(inputs=input_tensor, outputs=x)\n",
    "\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer='adadelta',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.utils.vis_utils import plot_model\n",
    "from IPython.display import Image\n",
    "\n",
    "plot_model(model, to_file=\"model.png\", show_shapes=True)\n",
    "Image('model.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h = model.fit_generator(gen(128), steps_per_epoch=400, epochs=20, \n",
    "                        workers=4, pickle_safe=True, \n",
    "                        validation_data=gen(128), validation_steps=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 4))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(h.history['loss'])\n",
    "plt.plot(h.history['val_loss'])\n",
    "plt.legend(['loss', 'val_loss'])\n",
    "plt.ylabel('loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "for i in range(4):\n",
    "    plt.plot(h.history['val_c%d_acc' % (i+1)])\n",
    "plt.legend(['val_c%d_acc' % (i+1) for i in range(4)])\n",
    "plt.ylabel('acc')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylim(0.9, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 计算模型准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def evaluate(batch_size=128, steps=20):\n",
    "    batch_acc = 0\n",
    "    generator = gen(batch_size)\n",
    "    for i in tqdm(range(steps)):\n",
    "        X, y = next(generator)\n",
    "        y_pred = model.predict(X)\n",
    "        y_pred = np.argmax(y_pred, axis=-1)\n",
    "        y_true = np.argmax(y, axis=-1)\n",
    "        batch_acc += np.equal(y_true, y_pred).all(axis=0).mean()\n",
    "    return batch_acc / steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = next(gen(12))\n",
    "y_pred = model.predict(X)\n",
    "\n",
    "plt.figure(figsize=(16, 8))\n",
    "for i in range(12):\n",
    "    plt.subplot(3, 4, i+1)\n",
    "    plt.title('real: %s\\npred:%s'%(decode(y, i), decode(y_pred, i)))\n",
    "    plt.imshow(X[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
